在之前的课程中,我们重点学习了 逐元素操作 (例如矩阵上的基础 ReLU 操作)。这些操作属于 内存受限 因为 GPU 花费在将数据从高带宽内存(HBM)移动到寄存器上的时间,远多于执行数学计算的时间。
1. 为什么 GEMM 至关重要
通用矩阵乘法(GEMM)的计算复杂度为 $O(N^3)$,但仅需 $O(N^2)$ 的内存访问。这使我们能够利用巨大的算术吞吐量来隐藏内存延迟,因此它成为大语言模型的“核心心跳”。
2. 二维内存表示
物理内存是一维的。为了表示二维张量,我们使用 步幅(Strides)。一个常见的生产环境陷阱是 假设张量是连续的。如果你在指针计算中混淆了行与列的步幅,就会访问到‘幽灵’数据或引发内存违规。
3. 分块泛化
Triton 通过从 单个指针 转变为 指针块。通过使用二维分块(例如 $16 \times 16$),我们能充分利用高速缓存中的 数据复用 ,使数据在高速缓存中保持‘热态’,以便在写回全局内存前进行融合操作,如偏置加法或激活函数计算。
main.py
TERMINALbash — 80x24
> Ready. Click "Run" to execute.
>
QUESTION 1
Why is an elementwise ReLU on a large matrix considered 'memory-bound'?
The ReLU function requires complex transcendental math.
The ratio of arithmetic operations to memory loads is very low (1:1).
Matrices are naturally stored in CPU memory only.
Triton cannot process non-linear activations.
✅ Correct!
Correct! Because we perform only one operation per element loaded, the hardware spends most of its time waiting for the bus.❌ Incorrect
Arithmetic intensity is the ratio of work to memory access. Elementwise ops have very low intensity.QUESTION 2
What is the result of 'The Stride Trap' in production kernels?
The kernel runs significantly faster but with less precision.
Memory access violations or corrupted output due to incorrect address calculation on non-contiguous tensors.
The GPU automatically corrects the indexing using L2 cache.
The tensor is forced into a 1D shape by the compiler.
✅ Correct!
Yes. Assuming contiguity (stride=1) when a tensor is sliced or transposed leads to reading the wrong memory offsets.❌ Incorrect
Triton requires explicit stride handling; it won't 'guess' the layout if your math assumes contiguity.QUESTION 3
How does Triton represent a 2D tile of pointers?
By using a nested Python list of integers.
By broadcasting a 1D column vector and a 1D row vector of offsets together.
By launching multiple 1D kernels sequentially.
By allocating a special 2D register file.
✅ Correct!
Correct. `offs_m[:, None] + offs_n[None, :]` creates a 2D coordinate grid used for block loading.❌ Incorrect
Triton uses broadcasting to efficiently generate multidimensional pointer grids in a single program instance.QUESTION 4
Which operation benefits most from the O(N³) complexity shift to hide memory latency?
Vector Addition
Matrix Multiplication (GEMM)
Sigmoid Activation
Global Average Pooling
✅ Correct!
GEMM is compute-bound, meaning it does enough math to justify the cost of loading the data tiles.❌ Incorrect
The other options are O(N) or O(N²), which typically remain memory-bound.QUESTION 5
List three kernels in your current workflow that launch multiple PyTorch ops and might benefit from fusion.
Linear -> Bias -> ReLU; LayerNorm -> Dropout; Softmax -> Masking.
Print -> Log -> Sleep.
DataLoader -> Augmentation -> Storage.
These ops cannot be fused in Triton.
✅ Correct!
Reference Answer: 1. Linear -> ReLU (Common MLP block). 2. LayerNorm -> Dropout (Transformer residual). 3. Softmax -> Masking (Attention mechanism). Fusing these avoids intermediate HBM writes.❌ Incorrect
Look for sequences where a large tensor is modified by simple elementwise or reduction steps.Case Study: The Contiguity Crisis
Debugging non-contiguous tensor access in production
A developer writes a custom Triton kernel for a Linear Layer. On standard training data, it works perfectly. However, during inference, the input tensor is frequently 'sliced' (e.g., `x[:, :hidden_dim/2]`), which changes its stride without changing its memory layout. The kernel begins outputting 'NaN' and random noise.
Q
Why did the kernel fail when the input was sliced?
Solution:
Slicing usually creates a non-contiguous view. If the kernel assumed the row stride was equal to the number of columns (width), but the physical memory jump to the next row remained the original width, the kernel would read 'stale' data from the unsliced portion of memory.
Slicing usually creates a non-contiguous view. If the kernel assumed the row stride was equal to the number of columns (width), but the physical memory jump to the next row remained the original width, the kernel would read 'stale' data from the unsliced portion of memory.
Q
How should the pointer arithmetic be updated to handle this?
Solution:
The kernel must accept `stride_m` and `stride_n` as arguments. Instead of `ptr = base + i * width + j`, it must use `ptr = base + i * stride_m + j * stride_n` to respect the actual memory mapping.
The kernel must accept `stride_m` and `stride_n` as arguments. Instead of `ptr = base + i * width + j`, it must use `ptr = base + i * stride_m + j * stride_n` to respect the actual memory mapping.